In [1]:
# Define experiment parameters
year = "201516"
target_col = "blue_collar"  # 'white_collar', 'blue_collar', 'has_occ'
sample_weight_col = 'women_weight'
In [2]:
# Define resource utilization parameters
random_state = 42
n_jobs_clf = 16
n_jobs_cv = 4
cv_folds = 5
In [3]:
import numpy as np
np.random.seed(random_state)

import pandas as pd
pd.set_option('display.max_columns', 500)

import matplotlib.pylab as pl

from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score

from sklearn.utils.class_weight import compute_class_weight

import lightgbm
from lightgbm import LGBMClassifier

from sklearn.model_selection import train_test_split
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.model_selection import StratifiedKFold

import shap

import pickle
from joblib import dump, load

Prepare Dataset

In [4]:
# Load dataset
dataset = pd.read_csv(f"data/women_work_data_{year}.csv")
print("Loaded dataset: ", dataset.shape)
dataset.head()
Loaded dataset:  (111398, 26)
Out[4]:
Unnamed: 0 case_id_str line_no country_code cluster_no hh_no state wealth_index hh_religion caste women_weight women_anemic obese_female urban freq_tv age occupation years_edu hh_members no_children_below5 white_collar blue_collar no_occ has_occ year total_children
0 8 1000117.0 2 IA6 10001 17 andaman and nicobar islands middle hindu NaN 0.191636 1.0 0.0 1.0 3.0 23.0 0.0 10.0 2.0 0.0 0.0 0.0 1.0 0.0 2015.0 0.0
1 9 1000120.0 1 IA6 10001 20 andaman and nicobar islands richer hindu none of above 0.191636 0.0 0.0 1.0 3.0 35.0 8.0 8.0 3.0 0.0 0.0 1.0 0.0 1.0 2015.0 2.0
2 11 1000129.0 2 IA6 10001 29 andaman and nicobar islands richest muslim other backward class 0.191636 1.0 0.0 1.0 3.0 46.0 0.0 12.0 3.0 0.0 0.0 0.0 1.0 0.0 2015.0 2.0
3 12 1000129.0 3 IA6 10001 29 andaman and nicobar islands richest muslim other backward class 0.191636 1.0 0.0 1.0 3.0 17.0 0.0 11.0 3.0 0.0 0.0 0.0 1.0 0.0 2015.0 0.0
4 13 1000130.0 2 IA6 10001 30 andaman and nicobar islands richer christian scheduled caste 0.191636 1.0 1.0 1.0 3.0 30.0 0.0 8.0 5.0 0.0 0.0 0.0 1.0 0.0 2015.0 3.0
In [5]:
# See distribution of target values
print("Target column distribution:\n", dataset[target_col].value_counts(dropna=False))
Target column distribution:
 0.0    82970
1.0    28428
Name: blue_collar, dtype: int64
In [6]:
# Drop samples where the target is missing
dataset.dropna(axis=0, subset=[target_col, sample_weight_col], inplace=True)
print("Drop missing targets: ", dataset.shape)
Drop missing targets:  (111398, 26)
In [7]:
# Drop samples where age < 21
dataset = dataset[dataset['age'] >= 21]
print("Drop under-21 samples: ", dataset.shape)
Drop under-21 samples:  (86825, 26)
In [8]:
# See new distribution of target values
print("Target column distribution:\n", dataset[target_col].value_counts(dropna=False))
Target column distribution:
 0.0    62680
1.0    24145
Name: blue_collar, dtype: int64
In [9]:
# Post-processing

# Group SC/ST castes together
dataset['caste'][dataset['caste'] == 'scheduled caste'] = 'sc/st'
dataset['caste'][dataset['caste'] == 'scheduled tribe'] = 'sc/st'
if year == "200506":
    dataset['caste'][dataset['caste'] == '9'] = "don\'t know"

# Fix naming for General caste
dataset['caste'][dataset['caste'] == 'none of above'] = 'general'

if year == "201516":
    # Convert wealth index from str to int values
    wi_dict = {'poorest': 0, 'poorer': 1, 'middle': 2, 'richer': 3, 'richest': 4}
    dataset['wealth_index'] = [wi_dict[wi] for wi in dataset['wealth_index']]
/home/chaitanya/miniconda3/envs/tf_gpu/lib/python3.6/site-packages/ipykernel_launcher.py:5: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  """
/home/chaitanya/miniconda3/envs/tf_gpu/lib/python3.6/site-packages/ipykernel_launcher.py:10: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  # Remove the CWD from sys.path while we load stuff.
In [10]:
# Define feature columns
x_cols_categorical = ['state', 'hh_religion', 'caste']
x_cols_binary = ['urban', 'women_anemic', 'obese_female']
x_cols_numeric = ['age', 'years_edu', 'wealth_index', 'hh_members', 'no_children_below5', 'total_children', 'freq_tv']
x_cols = x_cols_categorical + x_cols_binary + x_cols_numeric
print("Feature columns:\n", x_cols)
Feature columns:
 ['state', 'hh_religion', 'caste', 'urban', 'women_anemic', 'obese_female', 'age', 'years_edu', 'wealth_index', 'hh_members', 'no_children_below5', 'total_children', 'freq_tv']
In [11]:
# Drop samples with missing values in feature columns
dataset.dropna(axis=0, subset=x_cols, inplace=True)
print("Drop missing feature value rows: ", dataset.shape)
Drop missing feature value rows:  (81816, 26)
In [12]:
# Separate target column
targets = dataset[target_col]
# Separate sampling weight column
sample_weights = dataset[sample_weight_col]
# Drop columns which are not part of features
dataset.drop(columns=[col for col in dataset.columns if col not in x_cols], axis=1, inplace=True)
print("Drop extra columns: ", dataset.shape)
Drop extra columns:  (81816, 13)
In [13]:
# Obtain one-hot encodings for the caste column
dataset = pd.get_dummies(dataset, columns=['caste'])
x_cols_categorical.remove('caste')  # Remove 'caste' from categorical variables list
print("Caste to one-hot: ", dataset.shape)
Caste to one-hot:  (81816, 16)
In [14]:
dataset_display = dataset.copy()
dataset_display.columns = ['State', 'Wealth Index', 'Hh. Religion', 'Anemic', 'Obese',
                           'Residence Type', 'Freq. of TV', 'Age', 'Yrs. of Education', 'Hh. Members',
                           'Children Below 5', 'Total Children', 'Unknown Caste',
                           'General Caste', 'OBC Caste', 'Sc/St Caste']
print("Create copy for visualization: ", dataset_display.shape)
dataset_display.head()
Create copy for visualization:  (81816, 16)
Out[14]:
State Wealth Index Hh. Religion Anemic Obese Residence Type Freq. of TV Age Yrs. of Education Hh. Members Children Below 5 Total Children Unknown Caste General Caste OBC Caste Sc/St Caste
1 andaman and nicobar islands 3 hindu 0.0 0.0 1.0 3.0 35.0 8.0 3.0 0.0 2.0 0 1 0 0
2 andaman and nicobar islands 4 muslim 1.0 0.0 1.0 3.0 46.0 12.0 3.0 0.0 2.0 0 0 1 0
4 andaman and nicobar islands 3 christian 1.0 1.0 1.0 3.0 30.0 8.0 5.0 0.0 3.0 0 0 0 1
5 andaman and nicobar islands 3 christian 1.0 0.0 1.0 3.0 21.0 12.0 5.0 0.0 0.0 0 0 0 1
7 andaman and nicobar islands 4 hindu 1.0 1.0 1.0 3.0 40.0 8.0 2.0 0.0 2.0 0 1 0 0
In [15]:
# Obtain integer encodings for other categorical features
for col in x_cols_categorical:
    dataset[col] = pd.factorize(dataset[col])[0]
print("Categoricals to int encodings: ", dataset.shape)
Categoricals to int encodings:  (81816, 16)
In [16]:
dataset.head()
Out[16]:
state wealth_index hh_religion women_anemic obese_female urban freq_tv age years_edu hh_members no_children_below5 total_children caste_don't know caste_general caste_other backward class caste_sc/st
1 0 3 0 0.0 0.0 1.0 3.0 35.0 8.0 3.0 0.0 2.0 0 1 0 0
2 0 4 1 1.0 0.0 1.0 3.0 46.0 12.0 3.0 0.0 2.0 0 0 1 0
4 0 3 2 1.0 1.0 1.0 3.0 30.0 8.0 5.0 0.0 3.0 0 0 0 1
5 0 3 2 1.0 0.0 1.0 3.0 21.0 12.0 5.0 0.0 0.0 0 0 0 1
7 0 4 0 1.0 1.0 1.0 3.0 40.0 8.0 2.0 0.0 2.0 0 1 0 0
In [17]:
# Create Training, Validation and Test sets
X_train, X_test, Y_train, Y_test, W_train, W_test = train_test_split(dataset, targets, sample_weights, test_size=0.05, random_state=random_state, stratify=targets)
# X_train, X_val, Y_train, Y_val, W_train, W_val = train_test_split(X_train, Y_train, W_train, test_size=0.1)
print("Training set: ", X_train.shape, Y_train.shape, W_train.shape)
# print("Validation set: ", X_val.shape, Y_val.shape, W_val.shape)
print("Test set: ", X_test.shape, Y_test.shape, W_test.shape)
train_cw = compute_class_weight("balanced", classes=np.unique(Y_train), y=Y_train)
print("Class weights: ", train_cw)
Training set:  (77725, 16) (77725,) (77725,)
Test set:  (4091, 16) (4091,) (4091,)
Class weights:  [0.69719775 1.76776292]

Build LightGBM Classifier

In [18]:
# # Define LightGBM Classifier
# model = LGBMClassifier(boosting_type='gbdt', 
#                        feature_fraction=0.8,  
#                        learning_rate=0.01,
#                        max_bins=64,
#                        max_depth=-1,
#                        min_child_weight=0.001,
#                        min_data_in_leaf=50,
#                        min_split_gain=0.0,
#                        num_iterations=1000,
#                        num_leaves=64,
#                        reg_alpha=0,
#                        reg_lambda=1,
#                        subsample_for_bin=200000,
#                        is_unbalance=True,
#                        random_state=random_state, 
#                        n_jobs=n_jobs_clf, 
#                        silent=True, 
#                        importance_type='split')
In [19]:
# # Fit model on training set
# model.fit(X_train, Y_train, sample_weight=W_train.values, 
#           #categorical_feature=x_cols_categorical,
#           categorical_feature=[])
In [20]:
# # Make predictions on Test set
# predictions = model.predict(X_test)
# print(accuracy_score(Y_test, predictions))
# print(f1_score(Y_test, predictions))
# print(confusion_matrix(Y_test, predictions))
# print(classification_report(Y_test, predictions))
In [21]:
# # Save trained model
# dump(model, f'models/{target_col}-{year}-model.joblib')
# del model
In [22]:
# # Define hyperparameter grid
# param_grid = {
#     'num_leaves': [8, 32, 64],
#     'min_data_in_leaf': [10, 20, 50],
#     'max_depth': [-1], 
#     'learning_rate': [0.01, 0.1], 
#     'num_iterations': [1000, 3000, 5000], 
#     'subsample_for_bin': [200000],
#     'min_split_gain': [0.0], 
#     'min_child_weight': [0.001],
#     'feature_fraction': [0.8, 1.0], 
#     'reg_alpha': [0], 
#     'reg_lambda': [0, 1],
#     'max_bin': [64, 128, 255]
# }
In [23]:
# # Define LightGBM Classifier
# clf = LGBMClassifier(boosting_type='gbdt',
#                      objective='binary', 
#                      is_unbalance=True,
#                      random_state=random_state,
#                      n_jobs=n_jobs_clf, 
#                      silent=True, 
#                      importance_type='split')

# # Define K-fold cross validation splitter
# kfold = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=random_state)

# # Perform grid search
# model = GridSearchCV(clf, param_grid=param_grid, scoring='f1', n_jobs=n_jobs_cv, cv=kfold, refit=True, verbose=3)
# model.fit(X_train, Y_train, 
#           sample_weight=W_train.values, 
#           #categorical_feature=x_cols_categorical,
#           categorical_feature=[])

# print('\n All results:')
# print(model.cv_results_)
# print('\n Best estimator:')
# print(model.best_estimator_)
# print('\n Best hyperparameters:')
# print(model.best_params_)
In [24]:
# # Make predictions on Test set
# predictions = model.predict(X_test)
# print(accuracy_score(Y_test, predictions))
# print(f1_score(Y_test, predictions, average='micro'))
# print(confusion_matrix(Y_test, predictions))
# print(classification_report(Y_test, predictions))
In [25]:
# # Save trained model
# dump(model, f'models/{target_col}-{year}-gridsearch.joblib')
# del model

Load LightGBM Classifier

In [26]:
# model = load(f'models/{target_col}-{year}-model.joblib')
model = load(f'models/{target_col}-{year}-gridsearch.joblib').best_estimator_
In [27]:
# Sanity check: Make predictions on Test set
predictions = model.predict(X_test)
print(accuracy_score(Y_test, predictions))
print(f1_score(Y_test, predictions))
print(confusion_matrix(Y_test, predictions))
print(classification_report(Y_test, predictions))
0.6814959667562943
0.5511539786427833
[[1988  946]
 [ 357  800]]
              precision    recall  f1-score   support

         0.0       0.85      0.68      0.75      2934
         1.0       0.46      0.69      0.55      1157

   micro avg       0.68      0.68      0.68      4091
   macro avg       0.65      0.68      0.65      4091
weighted avg       0.74      0.68      0.70      4091

In [28]:
# Overfitting check: Make predictions on Train set
predictions = model.predict(X_train)
print(accuracy_score(Y_train, predictions))
print(f1_score(Y_train, predictions))
print(confusion_matrix(Y_train, predictions))
print(classification_report(Y_train, predictions))
0.7163460919909939
0.5963049090875798
[[39395 16346]
 [ 5701 16283]]
              precision    recall  f1-score   support

         0.0       0.87      0.71      0.78     55741
         1.0       0.50      0.74      0.60     21984

   micro avg       0.72      0.72      0.72     77725
   macro avg       0.69      0.72      0.69     77725
weighted avg       0.77      0.72      0.73     77725


Visualizations/Explainations

Note that these plot just explain how the XGBoost model works, not nessecarily how reality works. Since the XGBoost model is trained from observational data, it is not nessecarily a causal model, and so just because changing a factor makes the model's prediction of winning go up, does not always mean it will raise your actual chances.

In [29]:
# print the JS visualization code to the notebook
shap.initjs()

What makes a measure of feature importance good or bad?

  1. Consistency: Whenever we change a model such that it relies more on a feature, then the attributed importance for that feature should not decrease.
  2. Accuracy. The sum of all the feature importances should sum up to the total importance of the model. (For example if importance is measured by the R² value then the attribution to each feature should sum to the R² of the full model)

If consistency fails to hold, then we can’t compare the attributed feature importances between any two models, because then having a higher assigned attribution doesn’t mean the model actually relies more on that feature.

If accuracy fails to hold then we don’t know how the attributions of each feature combine to represent the output of the whole model. We can’t just normalize the attributions after the method is done since this might break the consistency of the method.

Using Tree SHAP for interpretting the model

In [30]:
explainer = shap.TreeExplainer(model)
# shap_values = explainer.shap_values(dataset)
shap_values = pickle.load(open(f'res/{target_col}-{year}-shapvals.obj', 'rb'))
In [31]:
# Visualize a single prediction
shap.force_plot(explainer.expected_value, shap_values[0,:], dataset_display.iloc[0,:])
Out[31]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

The above explanation shows features each contributing to push the model output from the base value (the average model output over the training dataset we passed) to the model output. Features pushing the prediction higher are shown in red, those pushing the prediction lower are in blue.

If we take many explanations such as the one shown above, rotate them 90 degrees, and then stack them horizontally, we can see explanations for an entire dataset (in the notebook this plot is interactive):

In [32]:
# Visualize many predictions
subsample = np.random.choice(len(dataset), 1000)  # Take random sub-sample
shap.force_plot(explainer.expected_value, shap_values[subsample,:], dataset_display.iloc[subsample,:])
Out[32]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

Summary Plots

In [33]:
for col, sv in zip(dataset.columns, np.abs(shap_values).mean(0)):
    print(f"{col} - {sv}")
state - 0.3334489747560986
wealth_index - 0.35427714391447723
hh_religion - 0.0950417809003661
women_anemic - 0.012567775101135376
obese_female - 0.07499836179360791
urban - 0.10343178383604486
freq_tv - 0.04726483737390738
age - 0.17220589031757438
years_edu - 0.3871120531216245
hh_members - 0.04451344216268129
no_children_below5 - 0.14378756152336108
total_children - 0.05632185140851888
caste_don't know - 0.0007036416854973273
caste_general - 0.07180249029912829
caste_other backward class - 0.009937315407254633
caste_sc/st - 0.1244942904077975
In [34]:
shap.summary_plot(shap_values, dataset, plot_type="bar")

The above figure shows the global mean(|Tree SHAP|) method applied to our model.

The x-axis is essentially the average magnitude change in model output when a feature is “hidden” from the model (for this model the output has log-odds units). “Hidden” means integrating the variable out of the model. Since the impact of hiding a feature changes depending on what other features are also hidden, Shapley values are used to enforce consistency and accuracy.

However, since we now have individualized explanations for every person in our dataset, to get an overview of which features are most important for a model we can plot the SHAP values of every feature for every sample. The plot below sorts features by the sum of SHAP value magnitudes over all samples, and uses SHAP values to show the distribution of the impacts each feature has on the model output. The color represents the feature value (red high, blue low):

In [35]:
shap.summary_plot(shap_values, dataset_display)
  • Every person has one dot on each row.
  • The x position of the dot is the impact of that feature on the model’s prediction for the person.
  • The color of the dot represents the value of that feature for the customer. Categorical variables are colored grey.
  • Dots that don’t fit on the row pile up to show density (since our dataset is large).
  • Since the XGBoost model has a logistic loss the x-axis has units of log-odds (Tree SHAP explains the change in the margin output of the model).

How to use this: We can make analysis similar to the blog post for interpretting our models.


SHAP Dependence Plots

Next, to understand how a single feature effects the output of the model we can plot the SHAP value of that feature vs. the value of the feature for all the examples in a dataset. SHAP dependence plots show the effect of a single feature across the whole dataset. They plot a feature's value vs. the SHAP value of that feature across many samples.

SHAP dependence plots are similar to partial dependence plots, but account for the interaction effects present in the features, and are only defined in regions of the input space supported by data. The vertical dispersion of SHAP values at a single feature value is driven by interaction effects, and another feature is chosen for coloring to highlight possible interactions. One the benefits of SHAP dependence plots over traditional partial dependence plots is this ability to distigush between between models with and without interaction terms. In other words, SHAP dependence plots give an idea of the magnitude of the interaction terms through the vertical variance of the scatter plot at a given feature value.

Good example of using Dependency Plots: https://slundberg.github.io/shap/notebooks/League%20of%20Legends%20Win%20Prediction%20with%20XGBoost.html

Plots for 'age'

In [36]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('age', 'age'),
         ('age', 'urban'),
         ('age', 'caste_sc/st'),
         ('age', 'caste_general'),
         ('age', 'wealth_index'),
         ('age', 'years_edu'),
         ('age', 'no_children_below5'),
         ('age', 'total_children'),
         ('hh_religion', 'age'),
         ('state', 'age')]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: age, Interaction Feature: age
Feature: age, Interaction Feature: urban
Feature: age, Interaction Feature: caste_sc/st
Feature: age, Interaction Feature: caste_general
Feature: age, Interaction Feature: wealth_index
Feature: age, Interaction Feature: years_edu
Feature: age, Interaction Feature: no_children_below5
Feature: age, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: age
Feature: state, Interaction Feature: age

Plots for 'wealth_index'

In [37]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('wealth_index', 'wealth_index'),
         ('wealth_index', 'age'), 
         ('wealth_index', 'urban'),
         ('wealth_index', 'caste_sc/st'),
         ('wealth_index', 'caste_general'),
         ('wealth_index', 'years_edu'),
         ('wealth_index', 'no_children_below5'),
         ('wealth_index', 'total_children'),
         ('hh_religion', 'wealth_index'),
         ('state', 'wealth_index')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: wealth_index, Interaction Feature: wealth_index
Feature: wealth_index, Interaction Feature: age
Feature: wealth_index, Interaction Feature: urban
Feature: wealth_index, Interaction Feature: caste_sc/st
Feature: wealth_index, Interaction Feature: caste_general
Feature: wealth_index, Interaction Feature: years_edu
Feature: wealth_index, Interaction Feature: no_children_below5
Feature: wealth_index, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: wealth_index
Feature: state, Interaction Feature: wealth_index

Plots for 'years_edu'

In [38]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('years_edu', 'years_edu'),
         ('years_edu', 'age'), 
         ('years_edu', 'urban'),
         ('years_edu', 'caste_sc/st'),
         ('years_edu', 'caste_general'),
         ('years_edu', 'wealth_index'),
         ('years_edu', 'no_children_below5'),
         ('years_edu', 'total_children'),
         ('hh_religion', 'years_edu'),
         ('state', 'years_edu')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: years_edu, Interaction Feature: years_edu
Feature: years_edu, Interaction Feature: age
Feature: years_edu, Interaction Feature: urban
Feature: years_edu, Interaction Feature: caste_sc/st
Feature: years_edu, Interaction Feature: caste_general
Feature: years_edu, Interaction Feature: wealth_index
Feature: years_edu, Interaction Feature: no_children_below5
Feature: years_edu, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: years_edu
Feature: state, Interaction Feature: years_edu

Plots for 'caste_sc/st'

In [39]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('caste_sc/st', 'caste_sc/st'),
         ('caste_sc/st', 'age'), 
         ('caste_sc/st', 'urban'),
         ('caste_sc/st', 'years_edu'),
         ('caste_sc/st', 'wealth_index'),
         ('caste_sc/st', 'no_children_below5'),
         ('caste_sc/st', 'total_children'),
         ('hh_religion', 'caste_sc/st'),
         ('state', 'caste_sc/st')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: caste_sc/st, Interaction Feature: caste_sc/st
Feature: caste_sc/st, Interaction Feature: age
Feature: caste_sc/st, Interaction Feature: urban
Feature: caste_sc/st, Interaction Feature: years_edu
Feature: caste_sc/st, Interaction Feature: wealth_index
Feature: caste_sc/st, Interaction Feature: no_children_below5
Feature: caste_sc/st, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: caste_sc/st
Feature: state, Interaction Feature: caste_sc/st

Plots for 'caste_general'

In [40]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('caste_general', 'caste_general'),
         ('caste_general', 'age'), 
         ('caste_general', 'urban'),
         ('caste_general', 'years_edu'),
         ('caste_general', 'wealth_index'),
         ('caste_general', 'no_children_below5'),
         ('caste_general', 'total_children'),
         ('hh_religion', 'caste_general'),
         ('state', 'caste_general')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: caste_general, Interaction Feature: caste_general
Feature: caste_general, Interaction Feature: age
Feature: caste_general, Interaction Feature: urban
Feature: caste_general, Interaction Feature: years_edu
Feature: caste_general, Interaction Feature: wealth_index
Feature: caste_general, Interaction Feature: no_children_below5
Feature: caste_general, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: caste_general
Feature: state, Interaction Feature: caste_general

Visualizing Bar/Summary plots split by age bins

In [41]:
bins = [(21,25), (26,30), (31,35), (36,40), (41,45), (46,50)]

for low, high in bins:
    # Sample dataset by age range
    dataset_sample = dataset[(dataset.age > low) & (dataset.age <= high)]
    dataset_display_sample = dataset_display[(dataset.age > low) & (dataset.age <= high)]
    targets_sample = targets[(dataset.age > low) & (dataset.age <= high)]
    shap_values_sample = shap_values[(dataset.age > low) & (dataset.age <= high)]
    
    print("\nAge Range: {} - {} years".format(low, high))
    print("Sample size: {}\n".format(len(dataset_sample)))
    
    for col, sv in zip(dataset_sample.columns, np.abs(shap_values_sample).mean(0)):
        print(f"{col} - {sv}")
    
    # Summary plots
    shap.summary_plot(shap_values_sample, dataset_sample, plot_type="bar")
    shap.summary_plot(shap_values_sample, dataset_display_sample)
Age Range: 21 - 25 years
Sample size: 14034

state - 0.3026053978829125
wealth_index - 0.3537775609239908
hh_religion - 0.07386638490990269
women_anemic - 0.01091456013302559
obese_female - 0.04798707420412837
urban - 0.10450015428732949
freq_tv - 0.05950145631349227
age - 0.2921491142483017
years_edu - 0.41488667575833266
hh_members - 0.04731932949195551
no_children_below5 - 0.17364774359826188
total_children - 0.0716685530182734
caste_don't know - 0.0010235193591526044
caste_general - 0.05749155550858089
caste_other backward class - 0.009293519675678298
caste_sc/st - 0.11425547358654378
Age Range: 26 - 30 years
Sample size: 13706

state - 0.3301005487576073
wealth_index - 0.3559618625143623
hh_religion - 0.08843587123288825
women_anemic - 0.010553900073406603
obese_female - 0.06680178648320777
urban - 0.11841314205420557
freq_tv - 0.04509595589104865
age - 0.05241573640676622
years_edu - 0.3804742687473381
hh_members - 0.0483145100749847
no_children_below5 - 0.1813857076838349
total_children - 0.05395635534227774
caste_don't know - 0.0008964437195014538
caste_general - 0.0668907119811839
caste_other backward class - 0.010820102433579098
caste_sc/st - 0.11453339102193717
Age Range: 31 - 35 years
Sample size: 12526

state - 0.3495430766049304
wealth_index - 0.3446885510422806
hh_religion - 0.10910120861793177
women_anemic - 0.014695555729752053
obese_female - 0.0818960489283079
urban - 0.10491647564260433
freq_tv - 0.04264023802172609
age - 0.1632586271439636
years_edu - 0.36325943343832023
hh_members - 0.043088269198626634
no_children_below5 - 0.15763003437466738
total_children - 0.04810629541762849
caste_don't know - 0.0005482220241432459
caste_general - 0.07360421480461667
caste_other backward class - 0.010174978709587166
caste_sc/st - 0.12530036668076078
Age Range: 36 - 40 years
Sample size: 11283

state - 0.35533997395717315
wealth_index - 0.3515420946887581
hh_religion - 0.11159481736499423
women_anemic - 0.013685394306013272
obese_female - 0.08541778007955814
urban - 0.09559550153723782
freq_tv - 0.03963295076832347
age - 0.19353676067848535
years_edu - 0.3719416151959908
hh_members - 0.039788352967780535
no_children_below5 - 0.11050393541868181
total_children - 0.050371944714949594
caste_don't know - 0.0004529275029945085
caste_general - 0.07536526944052115
caste_other backward class - 0.010160717386944206
caste_sc/st - 0.13394603948676229
Age Range: 41 - 45 years
Sample size: 10087

state - 0.3461418985666017
wealth_index - 0.3602537679021365
hh_religion - 0.10573838525824424
women_anemic - 0.013157365871976082
obese_female - 0.09482486031549842
urban - 0.08944621807109131
freq_tv - 0.04299324056491018
age - 0.17777499975475158
years_edu - 0.38158619537433097
hh_members - 0.03934174547448142
no_children_below5 - 0.09625235503532308
total_children - 0.05481128314886503
caste_don't know - 0.00045174973389167924
caste_general - 0.08695508286191994
caste_other backward class - 0.009855056052289715
caste_sc/st - 0.13534962286887656
Age Range: 46 - 50 years
Sample size: 6013

state - 0.34287405764358186
wealth_index - 0.37567032942592365
hh_religion - 0.09396219614902841
women_anemic - 0.012471407232459696
obese_female - 0.09858026227732072
urban - 0.1045663101842607
freq_tv - 0.05046851249546494
age - 0.09521014695390322
years_edu - 0.4033645344310303
hh_members - 0.05136535071006243
no_children_below5 - 0.09934977543990349
total_children - 0.05099293413575399
caste_don't know - 0.000469501339851672
caste_general - 0.08300296169207906
caste_other backward class - 0.009493019576588923
caste_sc/st - 0.13663964760442418

SHAP Interaction Values

SHAP interaction values are a generalization of SHAP values to higher order interactions.

The model returns a matrix for every prediction, where the main effects are on the diagonal and the interaction effects are off-diagonal. The main effects are similar to the SHAP values you would get for a linear model, and the interaction effects captures all the higher-order interactions are divide them up among the pairwise interaction terms.

Note that the sum of the entire interaction matrix is the difference between the model's current output and expected output, and so the interaction effects on the off-diagonal are split in half (since there are two of each). When plotting interaction effects the SHAP package automatically multiplies the off-diagonal values by two to get the full interaction effect.

In [42]:
# Sample from dataset based on sample weights
dataset_ss = dataset.sample(10000, weights=sample_weights, random_state=random_state)
print(dataset_ss.shape)
dataset_display_ss = dataset_display.loc[dataset_ss.index]
print(dataset_display_ss.shape)
(10000, 16)
(10000, 16)
In [43]:
# Compute SHAP interaction values (time consuming)
# shap_interaction_values = explainer.shap_interaction_values(dataset_ss)
shap_interaction_values = pickle.load(open(f'res/{target_col}-{year}-shapints.obj', 'rb'))
In [44]:
shap.summary_plot(shap_interaction_values, dataset_display_ss, max_display=15)

Heatmap of SHAP Interaction Values

In [45]:
tmp = np.abs(shap_interaction_values).sum(0)
for i in range(tmp.shape[0]):
    tmp[i,i] = 0
inds = np.argsort(-tmp.sum(0))[:50]
tmp2 = tmp[inds,:][:,inds]
pl.figure(figsize=(12,12))
pl.imshow(tmp2)
pl.yticks(range(tmp2.shape[0]), dataset_display_ss.columns[inds], rotation=50.4, horizontalalignment="right")
pl.xticks(range(tmp2.shape[0]), dataset_display_ss.columns[inds], rotation=50.4, horizontalalignment="left")
pl.gca().xaxis.tick_top()
pl.show()

SHAP Interaction Value Dependence Plots

Running a dependence plot on the SHAP interaction values a allows us to separately observe the main effects and the interaction effects.

Below we plot the main effects for age and some of the interaction effects for age. It is informative to compare the main effects plot of age with the earlier SHAP value plot for age. The main effects plot has no vertical dispersion because the interaction effects are all captured in the off-diagonal terms.

Good example of how to infer interesting stuff from interaction values: https://slundberg.github.io/shap/notebooks/NHANES%20I%20Survival%20Model.html

In [46]:
shap.dependence_plot(
    ("age", "age"), 
    shap_interaction_values, dataset_ss, display_features=dataset_display_ss
)

Now we plot the interaction effects involving age (and other features after that). These effects capture all of the vertical dispersion that was present in the original SHAP plot but is missing from the main effects plot above.

Plots for 'age'

In [47]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('age', 'age'),
         ('age', 'urban'),
         ('age', 'caste_sc/st'),
         ('age', 'caste_general'),
         ('age', 'wealth_index'),
         ('age', 'years_edu'),
         ('age', 'no_children_below5'),
         ('age', 'total_children')]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: age, Interaction Feature: age
Feature: age, Interaction Feature: urban
Feature: age, Interaction Feature: caste_sc/st
Feature: age, Interaction Feature: caste_general
Feature: age, Interaction Feature: wealth_index
Feature: age, Interaction Feature: years_edu
Feature: age, Interaction Feature: no_children_below5
Feature: age, Interaction Feature: total_children

Plots for 'wealth_index'

In [48]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('wealth_index', 'wealth_index'),
         ('wealth_index', 'age'), 
         ('wealth_index', 'urban'),
         ('wealth_index', 'caste_sc/st'),
         ('wealth_index', 'caste_general'),
         ('wealth_index', 'years_edu'),
         ('wealth_index', 'no_children_below5'),
         ('wealth_index', 'total_children')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: wealth_index, Interaction Feature: wealth_index
Feature: wealth_index, Interaction Feature: age
Feature: wealth_index, Interaction Feature: urban
Feature: wealth_index, Interaction Feature: caste_sc/st
Feature: wealth_index, Interaction Feature: caste_general
Feature: wealth_index, Interaction Feature: years_edu
Feature: wealth_index, Interaction Feature: no_children_below5
Feature: wealth_index, Interaction Feature: total_children

Plots for 'years_edu'

In [49]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('years_edu', 'years_edu'),
         ('years_edu', 'age'), 
         ('years_edu', 'urban'),
         ('years_edu', 'caste_sc/st'),
         ('years_edu', 'caste_general'),
         ('years_edu', 'wealth_index'),
         ('years_edu', 'no_children_below5'),
         ('years_edu', 'total_children')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: years_edu, Interaction Feature: years_edu
Feature: years_edu, Interaction Feature: age
Feature: years_edu, Interaction Feature: urban
Feature: years_edu, Interaction Feature: caste_sc/st
Feature: years_edu, Interaction Feature: caste_general
Feature: years_edu, Interaction Feature: wealth_index
Feature: years_edu, Interaction Feature: no_children_below5
Feature: years_edu, Interaction Feature: total_children

Plots for 'caste_sc/st'

In [50]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('caste_sc/st', 'caste_sc/st'),
         ('caste_sc/st', 'age'), 
         ('caste_sc/st', 'urban'),
         ('caste_sc/st', 'years_edu'),
         ('caste_sc/st', 'wealth_index'),
         ('caste_sc/st', 'no_children_below5'),
         ('caste_sc/st', 'total_children')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: caste_sc/st, Interaction Feature: caste_sc/st
Feature: caste_sc/st, Interaction Feature: age
Feature: caste_sc/st, Interaction Feature: urban
Feature: caste_sc/st, Interaction Feature: years_edu
Feature: caste_sc/st, Interaction Feature: wealth_index
Feature: caste_sc/st, Interaction Feature: no_children_below5
Feature: caste_sc/st, Interaction Feature: total_children

Plots for 'caste_general'

In [51]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('caste_general', 'caste_general'),
         ('caste_general', 'age'), 
         ('caste_general', 'urban'),
         ('caste_general', 'years_edu'),
         ('caste_general', 'wealth_index'),
         ('caste_general', 'no_children_below5'),
         ('caste_general', 'total_children'),
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: caste_general, Interaction Feature: caste_general
Feature: caste_general, Interaction Feature: age
Feature: caste_general, Interaction Feature: urban
Feature: caste_general, Interaction Feature: years_edu
Feature: caste_general, Interaction Feature: wealth_index
Feature: caste_general, Interaction Feature: no_children_below5
Feature: caste_general, Interaction Feature: total_children